Cast GroupNorm fp16 backward gradient accumulators to fp16 (not bfloat16)#1253
Open
lollinng wants to merge 1 commit into
Open
Cast GroupNorm fp16 backward gradient accumulators to fp16 (not bfloat16)#1253lollinng wants to merge 1 commit into
lollinng wants to merge 1 commit into
Conversation
group_norm_backward set triton_dtype = bfloat16 for every non-fp32 input. For fp16 inputs that means the dW/dB gradient accumulators were rounded to bfloat16 (8 mantissa bits) before being atomic-added into the fp16 DW/DB buffers (10 mantissa bits) -- losing precision for no reason, since the buffers are fp16. Map fp16 -> fp16 so the atomic_add dtype matches the buffer dtype. This is a type-consistency / precision fix, not a crash fix: on the Triton version tested the bf16->fp16 atomic_add is silently coerced rather than erroring. Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Problem
group_norm_backwardpicks the dtype its dW/dB gradient accumulators are cast to before theatomic_addinto theDW/DBbuffers:The
DW/DBbuffers are allocated withW.dtype/B.dtype. For an fp16 model that's an fp16 buffer, buttriton_dtypeisbfloat16— so the fp16 gradients are rounded to bfloat16 (8 mantissa bits) and then atomic-added into an fp16 buffer (10 mantissa bits), losing precision for no reason. bf16 and fp16 are not interchangeable.Fix
Map
fp16 -> tl.float16so the accumulator/atomic dtype matches the buffer dtype. bf16 inputs are unchanged.Testing done (NVIDIA T4)
GroupNorm forward+backward vs
torch.nn.functional.group_norm, comparing the weight gradientdW[0](ref = 4.981):Both pass the tolerance check; the fix moves the fp16 gradient measurably closer to the reference. Honesty note: on the Triton version tested the old
bf16 -> fp16atomic_add is silently coerced rather than erroring, so this is a precision/type-consistency fix, not a crash fix.